import sys
import numpy as np
import matplotlib
import argparse
from matplotlib import pyplot as plt
from sklearn.calibration import calibration_curve

# import seaborn as sns
# Font sizes used by the figure code below: axis "Time" labels, subplot
# titles, and y-axis labels respectively.
label_fs, title_fs, ax_fs = 12, 10, 10

# Tick-label sizes are set globally so every subplot picks them up.
matplotlib.rcParams.update({
  'xtick.labelsize': 12,
  'ytick.labelsize': 9,
})

# ----------------------------------------------------------------------------

# Command-line interface: two input result files and one output figure path.
# All three options are mandatory.
parser = argparse.ArgumentParser()

for _flag, _description in (
  ('--bernoulli-plots', 'path to experiment results'),
  ('--adversarial-plots', 'path to experiment results'),
  ('--out', 'path to figure output'),
):
  parser.add_argument(_flag, help=_description, required=True)

# Parsed eagerly at import time; the rest of the script reads `args`.
args = parser.parse_args()

# ----------------------------------------------------------------------------


def _bucket_sizes(p, n_bins=10):
  lengths = list()
  iv_size = 1./n_bins
  for i in xrange(n_bins):
    l = len([p_j for p_j in p if i*iv_size <= p_j <= (i+1)*iv_size])
    if l:
      lengths.append(l)
  return lengths


# ----------------------------------------------------------------------------
# load data

def _read_series(handle):
  """Parse an open results file into a list of (name, probs) tuples.

  Each non-blank line must consist of two whitespace-separated fields: a
  name, and a comma-separated list of floats. Blank lines are skipped.

  Raises:
    ValueError: if a non-blank line does not have exactly two fields, or a
      value in the second field cannot be converted to float.
  """
  series = list()
  for line in handle:
    fields = line.strip().split()
    if not fields:
      # Tolerate empty lines (e.g. a trailing newline at end of file).
      continue
    name, prob_str = fields
    probs = [float(p) for p in prob_str.split(',')]
    series.append((name, probs))
  return series


# `with` closes both input files as soon as they have been read.
with open(args.bernoulli_plots) as pfb:
  bernoulli_data = _read_series(pfb)

with open(args.adversarial_plots) as pfa:
  adversarial_data = _read_series(pfa)

# ----------------------------------------------------------------------------
# generate figure

matplotlib.rcParams.update({'font.size': 7})
plt.figure(figsize=(6,3))

# One entry per subplot of the 2x2 grid:
#   (grid position, title, data set, curves, y-axis label, y-limits,
#    draw legend?, draw "Time" x-label?)
# `curves` lists (row index into the data set, legend label or None) in the
# order the lines are drawn. Only panel (a) carries labels and a legend.
_panels = (
  (221, '(a) Bernoulli setting: accuracy', bernoulli_data,
   ((1, 'Uncalibrated'), (2, 'Recalibrated'), (0, 'Only calibrated')),
   'L2 loss', None, True, False),
  (222, '(b) Bernoulli setting: calibration', bernoulli_data,
   ((4, None), (5, None), (3, None)),
   'Cal. error', [0,0.07], False, False),
  (223, '(c) Adversarial setting: accuracy', adversarial_data,
   ((1, None), (2, None), (0, None)),
   'L2 loss', [0.0,0.6], False, True),
  (224, '(d) Adversarial setting: calibration', adversarial_data,
   ((4, None), (5, None), (3, None)),
   'Cal. error', [0,0.25], False, True),
)

for (_pos, _heading, _dataset, _curves,
     _y_name, _y_range, _with_legend, _with_xlabel) in _panels:
  plt.subplot(_pos)
  plt.title(_heading, fontsize=title_fs)

  # The x axis is the index 0..T-1, with T taken from the first row.
  T = len(_dataset[0][1])
  for _row, _tag in _curves:
    _extra = {} if _tag is None else {'label': _tag}
    plt.plot(range(T), _dataset[_row][1], **_extra)

  plt.ylabel(_y_name, fontsize=ax_fs)
  if _with_legend:
    plt.legend(loc="upper right", fontsize=6)
  if _y_range is not None:
    plt.ylim(_y_range)
  if _with_xlabel:
    plt.xlabel('Time', fontsize=label_fs)

plt.tight_layout()
plt.savefig(args.out, bbox_inches='tight')
